import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
from sklearn.metrics.cluster import mutual_info_score
import math
from tqdm import tqdm

class MarkovBlanketSimulator:
    """Simulation of emergence of Markov blankets via attractor dynamics.

    The state vector (length ``dim_rep``) is split into a subsystem S (the
    first ``dim_s`` entries), a blanket B (the next ``dim_b``) and an
    environment E (the last ``dim_e``). :meth:`simulate` evolves it with a
    damped, noisy gradient step on :meth:`potential_function`, driven by
    memory, sensory and emotional inputs, and samples information-theoretic
    and curvature metrics every 20 steps.
    """

    def __init__(self,
                 dim_rep=10,      # Dimension of the representation space
                 dim_s=3,         # Dimension of subsystem S
                 dim_b=4,         # Dimension of Markov blanket B
                 dim_e=3,         # Dimension of environment E
                 dt=0.05,         # Time step
                 T=500,           # Total simulation time
                 gamma_min=0.5,   # Minimum metabolic damping
                 noise_amp=1.6):  # Noise amplitude

        self.dim_rep = dim_rep
        self.dim_s = dim_s
        self.dim_b = dim_b
        self.dim_e = dim_e
        assert dim_s + dim_b + dim_e == dim_rep, "Dimensions must sum to dim_rep"

        self.dt = dt
        self.T = T
        self.steps = int(T/dt)
        self.gamma_min = gamma_min
        self.noise_amp = noise_amp

        # Representation trajectory; row 0 (the initial state) stays at the origin.
        self.rep_history = np.zeros((self.steps, dim_rep))

        # Attractor landscape: a mixture of Gaussians centred on phase-shifted,
        # equal-amplitude sine profiles.
        self.n_attractors = 3
        self.attractors = np.zeros((self.n_attractors, dim_rep))
        for i in range(self.n_attractors):
            self.attractors[i] = np.sin(np.linspace(0, 2*np.pi, dim_rep) + i*2*np.pi/self.n_attractors)

        # Randomised Gaussian widths (a uniform width of 2.0 is an alternative).
        self.attractor_widths = np.random.uniform(1.0, 3.5, self.n_attractors)

        # Bilinear coupling matrices that tie S, B and E together in the potential;
        # the direct S-E coupling is deliberately weak.
        self.SB_coupling = self.create_coupling_matrix(dim_s, dim_b, strength=0.7)
        self.BE_coupling = self.create_coupling_matrix(dim_b, dim_e, strength=0.5)
        self.SE_coupling = self.create_coupling_matrix(dim_s, dim_e, strength=0.2)

        # Memory, sensory, and emotional inputs (one row per step)
        self.memory_trace = np.zeros((self.steps, dim_rep))
        self.sensory_inputs = np.zeros((self.steps, dim_rep))
        self.emotional_state = np.zeros((self.steps, dim_rep))

        # Tracking statistics; only entries flagged in `metrics_computed` hold
        # real samples, every other entry stays at its initial zero.
        self.mi_S_E = np.zeros(self.steps)  # Mutual information between S and E
        self.mi_S_E_given_B = np.zeros(self.steps)  # Conditional MI (approximation)
        self.mi_S_B = np.zeros(self.steps)  # MI between S and B
        self.mi_B_E = np.zeros(self.steps)  # MI between B and E
        self.info_bottleneck = np.zeros(self.steps)  # Information bottleneck objective
        self.curvature_metrics = np.zeros(self.steps)  # Curvature at B
        self.metrics_computed = np.zeros(self.steps, dtype=bool)  # Which steps were sampled

    def create_coupling_matrix(self, dim1, dim2, strength=1.0):
        """Return a (dim1, dim2) matrix of standard-normal draws scaled by ``strength``."""
        matrix = np.random.randn(dim1, dim2) * strength
        return matrix

    def potential_function(self, rep):
        """Multi-Gaussian potential with bilinear S/B/E coupling terms.

        Returns ``-wells + 0.01 * coupling`` where
        ``wells = sum_i -exp(-||rep - a_i||^2 / w_i)`` and
        ``coupling = -(s @ SB @ b + b @ BE @ e + s @ SE @ e)``.

        NOTE(review): ``wells`` is <= 0 and lowest at the attractor centres,
        but the function returns ``-wells``, so the centres are *maxima* of
        the returned value and the descent step in ``recursive_operator``
        (``rep - dt * grad``) is pushed away from them. Confirm this sign is
        intended before interpreting the results.
        """
        wells = 0
        for centre, width in zip(self.attractors, self.attractor_widths):
            wells += -np.exp(-np.sum((rep - centre)**2) / width)

        s = rep[:self.dim_s]
        b = rep[self.dim_s:self.dim_s+self.dim_b]
        e = rep[self.dim_s+self.dim_b:]

        # Bilinear coupling terms intended to shape the S-B-E dependency structure;
        # for 1-D slices each product is already a scalar.
        coupling_potential = -(s @ self.SB_coupling @ b)
        coupling_potential += -(b @ self.BE_coupling @ e)
        coupling_potential += -(s @ self.SE_coupling @ e)

        return -wells + 0.01 * coupling_potential

    def gradient_potential(self, rep):
        """Central-difference gradient of :meth:`potential_function` at ``rep``.

        Uses a step of 1e-4 per coordinate (2 * dim_rep potential evaluations).
        """
        eps = 1e-4
        grad = np.zeros_like(rep)

        for i in range(len(rep)):
            rep_plus = rep.copy()
            rep_plus[i] += eps
            rep_minus = rep.copy()
            rep_minus[i] -= eps

            grad[i] = (self.potential_function(rep_plus) -
                       self.potential_function(rep_minus)) / (2*eps)

        return grad

    def hessian_potential(self, rep):
        """Finite-difference Hessian of :meth:`potential_function` at ``rep``.

        Each mixed partial uses the four-point stencil
        ``(f(++) - f(+-) - f(-+) + f(--)) / (4 eps^2)`` with eps = 1e-4. Only
        the upper triangle is evaluated and mirrored, so the result is exactly
        symmetric (and costs half the potential evaluations).
        """
        eps = 1e-4
        dim = len(rep)
        hessian = np.zeros((dim, dim))

        for i in range(dim):
            for j in range(i, dim):
                rep_pp = rep.copy()
                rep_pp[i] += eps
                rep_pp[j] += eps

                rep_pm = rep.copy()
                rep_pm[i] += eps
                rep_pm[j] -= eps

                rep_mp = rep.copy()
                rep_mp[i] -= eps
                rep_mp[j] += eps

                rep_mm = rep.copy()
                rep_mm[i] -= eps
                rep_mm[j] -= eps

                value = (self.potential_function(rep_pp) -
                         self.potential_function(rep_pm) -
                         self.potential_function(rep_mp) +
                         self.potential_function(rep_mm)) / (4 * eps * eps)
                hessian[i, j] = value
                hessian[j, i] = value

        return hessian

    def metabolic_damping(self, t):
        """Damping ``gamma_min + max(0, 0.05 * sin(0.1 * t))`` at simulation time ``t``."""
        base_damping = self.gamma_min
        oscillation = 0.05 * np.sin(0.1 * t)
        return base_damping + max(0, oscillation)

    def update_memory(self, t):
        """Update memory trace with exponential decay.

        At ``t == 0`` the trace is seeded with ``rep_history[0]``; afterwards
        ``memory_trace[t] = 0.9 * memory_trace[t-1] + 0.1 * rep_history[t-1]``.
        """
        if t == 0:
            self.memory_trace[t] = self.rep_history[t]
        else:
            decay = 0.9
            self.memory_trace[t] = decay * self.memory_trace[t-1] + (1-decay) * self.rep_history[t-1]

    def generate_sensory_input(self, t):
        """Generate temporally coherent (AR(1)-style) sensory input at step ``t``.

        The subsystem components are scaled by 0.1 at every step, including
        ``t == 0``, so the input primarily affects the blanket and environment.
        """
        if t == 0:
            raw = np.random.randn(self.dim_rep) * 0.1
        else:
            coherence = 0.8
            noise = np.random.randn(self.dim_rep) * 0.1
            raw = coherence * self.sensory_inputs[t-1] + (1-coherence) * noise

        mask = np.ones(self.dim_rep)
        mask[:self.dim_s] = 0.1  # Reduce direct influence on subsystem
        self.sensory_inputs[t] = raw * mask

    def update_emotions(self, t, rep):
        """Update the smoothed emotional response to deviation from homeostasis.

        At ``t == 0`` the state is zero. Otherwise the response is 0.2 times
        the deviation of ``rep`` from a zero setpoint, restricted to the S and
        B components, blended as ``0.9 * previous + 0.1 * response``.
        """
        if t == 0:
            self.emotional_state[t] = np.zeros(self.dim_rep)
            return

        # Homeostatic setpoint (arbitrary for this simulation)
        setpoint = np.zeros(self.dim_rep)
        deviation = rep - setpoint

        # Emotion only acts on the subsystem and blanket components
        n_inner = self.dim_s + self.dim_b
        response = np.zeros(self.dim_rep)
        response[:n_inner] = deviation[:n_inner] * 0.2

        # Temporal smoothing
        self.emotional_state[t] = 0.9 * self.emotional_state[t-1] + 0.1 * response

    def recursive_operator(self, t, rep_prev, m, s, e):
        """Recursion operator from Lemma 1: one damped, noisy gradient step.

        ``new = rep_prev - dt * grad V + dt * (m + s + e - gamma * rep_prev)``
        plus Gaussian noise scaled by ``sqrt(dt) * noise_amp``; the result is
        then rescaled so its Euclidean norm is at most 10.
        """
        gamma = self.metabolic_damping(t * self.dt)

        grad_V = self.gradient_potential(rep_prev)

        # Integrate influences: memories, sensory inputs, emotions, and previous representation
        new_rep = rep_prev - self.dt * grad_V + self.dt * (m + s + e - gamma * rep_prev)

        noise = np.random.randn(self.dim_rep) * np.sqrt(self.dt) * self.noise_amp
        new_rep += noise

        # Cap magnitude
        norm = np.linalg.norm(new_rep)
        if norm > 10:
            new_rep = new_rep / norm * 10

        return new_rep

    def simulate(self):
        """Run the main simulation.

        Advances ``rep_history`` for ``steps - 1`` steps, samples metrics every
        20 steps via :meth:`compute_information_metrics`, then calls
        :meth:`compute_final_metrics` and :meth:`report_results`.
        """
        print("Starting Markov Blanket emergence simulation...")
        print(f"System dimensions: S={self.dim_s}, B={self.dim_b}, E={self.dim_e}")
        print(f"Running for {self.T} time units with dt={self.dt}")

        for t in tqdm(range(1, self.steps)):
            # Inputs at index t-1 are derived from the state at t-1
            self.update_memory(t-1)
            self.generate_sensory_input(t-1)
            self.update_emotions(t-1, self.rep_history[t-1])

            self.rep_history[t] = self.recursive_operator(
                t,
                self.rep_history[t-1],
                self.memory_trace[t-1],
                self.sensory_inputs[t-1],
                self.emotional_state[t-1]
            )

            # Periodically compute information-theoretic measures
            if t % 20 == 0:
                self.compute_information_metrics(t)

        self.compute_final_metrics()
        self.report_results()

    def compute_information_metrics(self, t):
        """Compute information-theoretic and curvature metrics at step ``t``.

        Mutual information between the binned S, B and E blocks is estimated
        with ``mutual_info_score`` over ``rep_history[max(0, t-100):t+1]``.
        Curvature is the mean eigenvalue of the B-B block of the potential's
        Hessian at ``rep_history[t]``. Results go to index ``t`` of the tracking
        arrays, and ``metrics_computed[t]`` is set so that
        :meth:`compute_final_metrics` averages only real samples.
        """
        t_window = max(0, t - 100)
        history = self.rep_history[t_window:t+1]

        S = history[:, :self.dim_s]
        B = history[:, self.dim_s:self.dim_s+self.dim_b]
        E = history[:, self.dim_s+self.dim_b:]

        # Bin each block on its own pooled range. np.digitize returns codes
        # 1..bins here, so the positional base-`bins` encoding below is
        # injective (bijective base-k numeration).
        bins = 20
        S_binned = np.digitize(S, bins=np.linspace(S.min(), S.max(), bins))
        B_binned = np.digitize(B, bins=np.linspace(B.min(), B.max(), bins))
        E_binned = np.digitize(E, bins=np.linspace(E.min(), E.max(), bins))

        # Collapse each multi-dimensional bin index into one integer label
        S_flat = np.sum(S_binned * np.array([bins**i for i in range(self.dim_s)]), axis=1)
        B_flat = np.sum(B_binned * np.array([bins**i for i in range(self.dim_b)]), axis=1)
        E_flat = np.sum(E_binned * np.array([bins**i for i in range(self.dim_e)]), axis=1)

        self.mi_S_B[t] = mutual_info_score(S_flat, B_flat)
        self.mi_B_E[t] = mutual_info_score(B_flat, E_flat)
        self.mi_S_E[t] = mutual_info_score(S_flat, E_flat)

        # Heuristic surrogate for I(S;E|B), not an exact identity:
        # max(0, I(S;E) - min(I(S;B), I(B;E))).
        self.mi_S_E_given_B[t] = max(0, self.mi_S_E[t] - min(self.mi_S_B[t], self.mi_B_E[t]))

        # Information bottleneck objective
        beta = 0.5  # Trade-off parameter
        self.info_bottleneck[t] = self.mi_S_B[t] - beta * self.mi_B_E[t]

        # Curvature at the blanket. hessian_potential is exactly symmetric, so
        # eigvalsh yields real eigenvalues; eigvals could return complex values
        # from round-off whose imaginary part would be discarded on assignment.
        hessian = self.hessian_potential(self.rep_history[t])
        blanket = slice(self.dim_s, self.dim_s + self.dim_b)
        blanket_hessian = hessian[blanket, blanket]
        try:
            eigenvalues = np.linalg.eigvalsh(blanket_hessian)
            self.curvature_metrics[t] = np.mean(eigenvalues)
        except np.linalg.LinAlgError:
            self.curvature_metrics[t] = np.nan

        self.metrics_computed[t] = True

    def _final_sample_indices(self, final_window):
        """Indices of sampled steps within the last ``final_window`` steps.

        Falls back to every sampled step when none fall in that window (e.g.
        very short runs); the result is empty if nothing was ever sampled.
        """
        sampled = np.flatnonzero(self.metrics_computed)
        recent = sampled[sampled >= self.steps - final_window]
        return recent if recent.size else sampled

    @staticmethod
    def _mean_at(series, idx):
        """Mean of ``series`` at indices ``idx``, or NaN when ``idx`` is empty."""
        return float(np.mean(series[idx])) if idx.size else np.nan

    def compute_final_metrics(self):
        """Summarise the run: variance-based convergence and averaged metrics.

        The metric arrays are only written at steps flagged in
        ``metrics_computed`` (every 20th step in :meth:`simulate`). Averages
        are therefore taken over those samples in the last ``final_window``
        steps, not over every step, which would dilute them with unwritten
        zeros.
        """
        final_window = 100
        sample_idx = self._final_sample_indices(final_window)

        self.avg_mi_S_B = self._mean_at(self.mi_S_B, sample_idx)
        self.avg_mi_B_E = self._mean_at(self.mi_B_E, sample_idx)

        # Convergence: mean per-dimension variance at the start vs. the end
        self.initial_variance = np.var(self.rep_history[:final_window], axis=0).mean()
        self.final_variance = np.var(self.rep_history[-final_window:], axis=0).mean()
        self.convergence_ratio = self.final_variance / self.initial_variance

        self.avg_mi_S_E = self._mean_at(self.mi_S_E, sample_idx)
        self.avg_mi_S_E_given_B = self._mean_at(self.mi_S_E_given_B, sample_idx)
        self.avg_info_bottleneck = self._mean_at(self.info_bottleneck, sample_idx)
        self.avg_curvature = self._mean_at(self.curvature_metrics, sample_idx)

        # Classify the blanket by the mean curvature (NaN -> "Undefined")
        if np.abs(self.avg_curvature) < 0.1:
            self.blanket_type = "Weak Blanket"
        elif self.avg_curvature > 0.1:
            self.blanket_type = "Strong Blanket"
        else:
            self.blanket_type = "Undefined"

    def report_results(self):
        """Print the summary computed by :meth:`compute_final_metrics` and the Theorem II.A checks."""
        print("\n# SIMULATION RESULTS")

        # Report convergence
        print("\nConvergence Analysis:")
        print(f"Initial representation variance: {self.initial_variance:.6f}")
        print(f"Final representation variance: {self.final_variance:.6f}")
        print(f"Convergence ratio: {self.convergence_ratio:.6f}")
        if self.convergence_ratio < 0.5:
            print("✓ Representation converged to attractor basins")
        else:
            print("✗ Representation did not fully converge")

        # Report information-theoretic metrics
        print("\nInformation-Theoretic Metrics:")
        print(f"Mutual Information (S;B): {self.avg_mi_S_B:.4f}")
        print(f"Mutual Information (B;E): {self.avg_mi_B_E:.4f}")
        print(f"Mutual Information (S;E): {self.avg_mi_S_E:.4f}")
        print(f"Conditional MI (S;E|B): {self.avg_mi_S_E_given_B:.4f}")

        # Check if B acts as a Markov blanket
        markov_blanket_threshold = 0.2
        if self.avg_mi_S_E_given_B < markov_blanket_threshold:
            print("✓ B acts as a Markov blanket between S and E")
            print(f"   Conditional independence: I(S;E|B) = {self.avg_mi_S_E_given_B:.4f} < {markov_blanket_threshold}")
        else:
            print("✗ B does not fully separate S and E")

        # Information bottleneck results
        print(f"\nInformation Bottleneck Objective: {self.avg_info_bottleneck:.4f}")
        if self.avg_info_bottleneck > 0:
            print("✓ B optimizes information flow between S and E")
        else:
            print("✗ B is not an optimal information bottleneck")

        # Geometry of the blanket
        print(f"\nBlanket Geometry Analysis:")
        print(f"Average curvature at B: {self.avg_curvature:.4f}")
        print(f"Blanket type: {self.blanket_type}")

        # Metabolic damping effect
        gamma_values = [self.metabolic_damping(t * self.dt) for t in range(self.steps)]
        print(f"\nMetabolic Damping Effect:")
        print(f"Min damping: {min(gamma_values):.4f}")
        print(f"Max damping: {max(gamma_values):.4f}")
        print(f"Average damping: {np.mean(gamma_values):.4f}")

        # Theorem II.A verification
        print("\n# THEOREM II.A VERIFICATION")
        print("\nThe simulation tests the following claims from Theorem II.A:")

        # Test 1: Conditional Independence
        test1 = self.avg_mi_S_E_given_B < markov_blanket_threshold
        print(f"1. Conditional Independence: I(S;E|B) ≈ 0")
        print(f"   Result: {'✓ VERIFIED' if test1 else '✗ NOT VERIFIED'}")
        print(f"   Value: {self.avg_mi_S_E_given_B:.4f}")

        # Test 2: Information Bottleneck Optimality
        test2 = self.avg_info_bottleneck > 0
        print(f"2. Information Bottleneck Optimality")
        print(f"   Result: {'✓ VERIFIED' if test2 else '✗ NOT VERIFIED'}")
        print(f"   Value: {self.avg_info_bottleneck:.4f}")

        # Test 3: Curvature Properties
        test3a = np.abs(self.avg_curvature) < 0.1  # Weak blanket
        test3b = self.avg_curvature > 0.1          # Strong blanket
        print(f"3. Curvature Properties at Markov Blanket")
        print(f"   Result: {'✓ VERIFIED' if (test3a or test3b) else '✗ NOT VERIFIED'}")
        print(f"   Value: {self.avg_curvature:.4f}")
        print(f"   Type: {self.blanket_type}")

        # Overall assessment
        if all([test1, test2, (test3a or test3b)]):
            print("\nCONCLUSION: All conditions of Theorem II.A are VERIFIED")
        else:
            print("\nCONCLUSION: Some conditions of Theorem II.A are NOT VERIFIED")

# Additional analysis functions
def analyze_trajectories(simulator):
    """Analyze representation trajectories and basin formations.

    Parameters
    ----------
    simulator : MarkovBlanketSimulator
        Simulator whose ``rep_history`` has been filled by ``simulate``.

    Returns
    -------
    distance_matrix : ndarray, shape (steps, steps)
        Euclidean distances between states. Only rows/columns at every 10th
        step (0, 10, 20, ...) are filled; all other entries are zero.
    kmeans : sklearn.cluster.KMeans
        Model with ``n_attractors`` clusters fitted on the second half of the
        trajectory.
    """
    from sklearn.cluster import KMeans
    from scipy.spatial.distance import cdist

    history = simulator.rep_history

    print("\nAnalyzing state space trajectories...")
    # Pairwise distances on a 10-step subsample, computed in one vectorized
    # call; the full (steps, steps) shape is kept for existing callers.
    distance_matrix = np.zeros((simulator.steps, simulator.steps))
    sample_idx = np.arange(0, simulator.steps, 10)
    sampled = history[sample_idx]
    distance_matrix[np.ix_(sample_idx, sample_idx)] = cdist(sampled, sampled)

    # Identify clusters/attractors on the last half of the run (after transients)
    n_clusters = simulator.n_attractors
    half_point = simulator.steps // 2
    late_states = history[half_point:]
    kmeans = KMeans(n_clusters=n_clusters).fit(late_states)
    labels = kmeans.labels_

    # Mean distance of each cluster's members to its centre (NaN if empty)
    cluster_distances = []
    for i, center in enumerate(kmeans.cluster_centers_):
        members = late_states[labels == i]
        if len(members):
            cluster_distances.append(np.linalg.norm(members - center, axis=1).mean())
        else:
            cluster_distances.append(np.nan)

    print("\nAttractor Basin Analysis:")
    print(f"Number of attractors identified: {n_clusters}")
    for i in range(n_clusters):
        count = np.sum(labels == i)
        percentage = count / len(labels) * 100
        print(f"Attractor {i+1}:")
        print(f"  - Points in basin: {count} ({percentage:.1f}%)")
        print(f"  - Mean distance to center: {cluster_distances[i]:.4f}")

    # Count cluster changes between states `window` steps apart
    if simulator.steps > 1000:
        window = 100
        starts = np.arange(half_point, simulator.steps - window, window)
        start_clusters = kmeans.predict(history[starts])
        end_clusters = kmeans.predict(history[starts + window])
        transitions = int(np.sum(start_clusters != end_clusters))

        print(f"\nBasin Transitions Analysis:")
        print(f"Number of transitions between attractors: {transitions}")
        if transitions > 0:
            print("✓ System exhibits metastability with transitions between attractors")
        else:
            print("✗ System remains in single attractor basin")

    return distance_matrix, kmeans

def run_simulation_experiment():
    """Run the full simulation experiment with analysis.

    Seeds NumPy's global RNG with 42, runs a default ``MarkovBlanketSimulator``,
    clusters its trajectory with ``analyze_trajectories``, prints four key
    findings and an overall conclusion, then draws the figure with
    ``plot_combined_landscapes``.
    """
    # Set random seed for reproducibility
    np.random.seed(42)

    # Create and run simulator
    sim = MarkovBlanketSimulator()
    sim.simulate()

    # Additional analysis (the sparse distance matrix is not used here)
    _, kmeans = analyze_trajectories(sim)

    # Thresholds shared by the individual findings and the overall conclusion
    markov_blanket_threshold = 0.2
    occupancy_cv_threshold = 0.5  # Allow 50% variation relative to mean

    print("\n# SIMULATION SUMMARY")

    print("\nKey Findings:")

    # Finding 1: Markov Blanket Emergence
    blanket_emerged = sim.avg_mi_S_E_given_B < markov_blanket_threshold
    if blanket_emerged:
        print("1. ✓ Markov blanket emerged between subsystem and environment")
        print(f"   The blanket effectively shields the subsystem from direct")
        print(f"   influences of the environment, with conditional mutual")
        print(f"   information I(S;E|B) = {sim.avg_mi_S_E_given_B:.4f}")
    else:
        print("1. ✗ Markov blanket did not fully emerge")

    # Finding 2: Attractor Dynamics - balance is the coefficient of variation
    # of the per-cluster point counts
    attractor_counts = np.array([np.sum(kmeans.labels_ == i) for i in range(sim.n_attractors)])
    coefficient_of_variation = np.std(attractor_counts) / np.mean(attractor_counts)
    occupancy_balanced = coefficient_of_variation < occupancy_cv_threshold
    if occupancy_balanced:
        print("2. ✓ Multiple stable attractors formed with balanced occupancy")
    else:
        print("2. ✗ Attractor occupancy was imbalanced")
        print(f"   Coefficient of variation: {coefficient_of_variation:.2f}, Threshold: {occupancy_cv_threshold:.2f}")

    # Finding 3: Information Bottleneck
    bottleneck_positive = sim.avg_info_bottleneck > 0
    if bottleneck_positive:
        print("3. ✓ Blanket optimizes the information bottleneck objective")
        print(f"   It retains information about S while compressing information")
        print(f"   about E, with bottleneck score = {sim.avg_info_bottleneck:.4f}")
    else:
        print("3. ✗ Blanket does not optimize information flow")

    # Finding 4: Metabolic Constraints
    damping_effect = np.mean([sim.metabolic_damping(t * sim.dt) for t in range(sim.steps)])
    damping_sufficient = damping_effect > sim.gamma_min
    if damping_sufficient:
        print("4. ✓ Metabolic damping constrains representation dynamics")
        print(f"   Average damping = {damping_effect:.4f}, ensuring system stability")
    else:
        print("4. ✗ Insufficient metabolic damping")

    # Final conclusion
    print("\nOverall Conclusion:")
    if all([blanket_emerged, occupancy_balanced, bottleneck_positive, damping_sufficient]):
        print("The simulation CONFIRMS Theorem II.A: Markov blankets emerge")
        print("through attractor dynamics in the Lyapunov landscape, with")
        print("the blanket forming at the boundary between internal and")
        print("external states of the system.")
    else:
        print("Some aspects of Theorem II.A were NOT CONFIRMED by the simulation.")
        print("Further parameter tuning or model refinement may be needed.")

    plot_combined_landscapes(sim, kmeans)


def _potential_on_pca_plane(simulator, origin, axis_u, axis_v, X, Y):
    """Evaluate ``simulator.potential_function`` at ``origin + X*axis_u + Y*axis_v`` for every grid node."""
    Z = np.zeros(X.shape)
    for idx in np.ndindex(X.shape):
        Z[idx] = simulator.potential_function(origin + X[idx] * axis_u + Y[idx] * axis_v)
    return Z


def plot_combined_landscapes(simulator, kmeans=None):
    """Combine all three landscape plots into a 3x1 grid.

    Panels:
    1. The trajectory projected on its first two principal components,
       coloured by time step.
    2. Filled contours of ``simulator.potential_function`` on the plane
       spanned by those two components through the trajectory mean, with the
       trajectory overlaid.
    3. A 3-D PCA view with potential contours drawn at the lowest trajectory
       z-value.

    If ``kmeans`` is given, its cluster centres are marked as attractors in
    panels 1 and 3. The figure is saved to
    ``plots/T.II.A_sim_emergence_of_markov_blankets.png`` (creating the
    directory if needed) and then shown. Note: updates the global
    ``plt.rcParams['font.size']``.
    """
    import os
    from sklearn.decomposition import PCA
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  registers the '3d' projection on older Matplotlib

    scale_factor = 0.5
    plt.rcParams.update({'font.size': 6.5})
    fig = plt.figure(figsize=(8.35, 2.55))  # Keep the width/height ratio

    history = simulator.rep_history
    time_index = np.arange(simulator.steps)
    mean_rep = np.mean(history, axis=0)

    # Plot 1: Attractor Landscape (2D PCA)
    ax1 = fig.add_subplot(1, 3, 1)
    pca_2d = PCA(n_components=2)
    rep_pca_2d = pca_2d.fit_transform(history)

    # Trajectory colored by time
    scatter1 = ax1.scatter(rep_pca_2d[:, 0], rep_pca_2d[:, 1], c=time_index,
                           cmap='viridis', alpha=0.5, s=1 * (scale_factor/4.5))
    ax1.scatter(rep_pca_2d[0, 0], rep_pca_2d[0, 1], c='red', s=100 * (scale_factor/4.5), marker='*', label='Start')
    ax1.scatter(rep_pca_2d[-1, 0], rep_pca_2d[-1, 1], c='green', s=100 * (scale_factor/4.5), marker='*', label='End')

    if kmeans is not None:
        centers_pca_2d = pca_2d.transform(kmeans.cluster_centers_)
        ax1.scatter(centers_pca_2d[:, 0], centers_pca_2d[:, 1], c='red', s=100 * (scale_factor/4.5), marker='X', label='Attractors')

    fig.colorbar(scatter1, ax=ax1, label='Time Step', fraction=0.05, pad=0.265)
    ax1.set_title('   Attractor Landscape Trajectory (PCA)')
    ax1.set_xlabel('PCA Component 1')
    ax1.set_ylabel('PCA Component 2')
    ax1.legend()

    # Plot 2: Potential Landscape (2D PCA with Contours); reuses the PCA fit
    # from plot 1 (refitting on the same data gives the same projection).
    ax2 = fig.add_subplot(1, 3, 2)
    grid_points = 50
    x = np.linspace(rep_pca_2d[:, 0].min()-0.5, rep_pca_2d[:, 0].max()+0.5, grid_points)
    y = np.linspace(rep_pca_2d[:, 1].min()-0.5, rep_pca_2d[:, 1].max()+0.5, grid_points)
    X, Y = np.meshgrid(x, y)
    Z = _potential_on_pca_plane(simulator, mean_rep,
                                pca_2d.components_[0], pca_2d.components_[1], X, Y)

    # NOTE: the 'managua' colormap requires a recent Matplotlib (added in 3.10)
    ax2.contourf(X, Y, Z, 20, cmap='managua')
    ax2.scatter(rep_pca_2d[:, 0], rep_pca_2d[:, 1], c='white', alpha=0.3, s=2 * (scale_factor/4.5))

    ax2.set_title(' Potential Landscape with Trajectory')
    ax2.set_xlabel('PCA Component 1')
    ax2.set_ylabel('PCA Component 2')

    # Plot 3: 3D Attractor Landscape
    ax3 = fig.add_subplot(1, 3, 3, projection='3d')
    ax3.tick_params(axis='both', which='major', labelsize=4.7 * scale_factor, pad=-3.5)
    pca_3d = PCA(n_components=3)
    rep_pca_3d = pca_3d.fit_transform(history)

    # Trajectory with color gradient by time
    ax3.scatter(rep_pca_3d[:, 0], rep_pca_3d[:, 1], rep_pca_3d[:, 2],
                c=time_index, cmap='viridis',
                alpha=0.2, s=0.5 * (scale_factor/4.5))

    # Mark start and end points
    ax3.scatter(rep_pca_3d[0, 0], rep_pca_3d[0, 1], rep_pca_3d[0, 2],
                color='red', s=32 * (scale_factor/1.1), marker='*', label='Start')
    ax3.scatter(rep_pca_3d[-1, 0], rep_pca_3d[-1, 1], rep_pca_3d[-1, 2],
                color='green', s=32 * (scale_factor/1.1), marker='*', label='End')

    if kmeans is not None:
        centers_pca_3d = pca_3d.transform(kmeans.cluster_centers_)
        ax3.scatter(centers_pca_3d[:, 0], centers_pca_3d[:, 1], centers_pca_3d[:, 2],
                    color='red', s=35 * scale_factor, marker='X', label='Attractors')

    # Potential on the plane of the first two 3D-PCA components (third set to 0),
    # on a coarser grid for speed, drawn as contours at the bottom of the plot.
    grid_points_3d = 30
    x3 = np.linspace(rep_pca_3d[:, 0].min()-0.5, rep_pca_3d[:, 0].max()+0.5, grid_points_3d)
    y3 = np.linspace(rep_pca_3d[:, 1].min()-0.5, rep_pca_3d[:, 1].max()+0.5, grid_points_3d)
    X3, Y3 = np.meshgrid(x3, y3)
    Z_potential = _potential_on_pca_plane(simulator, mean_rep,
                                          pca_3d.components_[0], pca_3d.components_[1], X3, Y3)

    z_min = np.min(rep_pca_3d[:, 2])
    ax3.contour3D(X3, Y3, Z_potential, levels=10, cmap='managua', alpha=0.5, zdir='z', offset=z_min, linewidths=2 * (scale_factor))

    ax3.set_title('        3D Attractor Landscape (PCA)')
    ax3.set_xlabel('PCA Component 1', labelpad=-10.5)
    ax3.set_ylabel('PCA Component 2', labelpad=-10.5)
    ax3.set_zlabel('PCA Component 3', labelpad=-10.5)
    # Transparent panes give a border-like appearance
    ax3.xaxis.pane.fill = False
    ax3.yaxis.pane.fill = False
    ax3.zaxis.pane.fill = False
    ax3.legend()
    ax3.set_zlim(-7.5, 7.5)

    # Adjust subplot spacing to make gaps consistent and align Plot 3
    plt.subplots_adjust(left=0.05, right=0.97, wspace=0.30, top=0.85, bottom=0.15)

    # Hand-tuned positions: widen ax1/ax2 and enlarge/shift ax3 to match heights
    pos1 = ax1.get_position()
    pos2 = ax2.get_position()
    pos3 = ax3.get_position()
    ax1.set_position([pos1.x0 + 0.009, pos1.y0, pos1.width + 0.056, pos1.height + 0.03])
    ax2.set_position([pos2.x0 + 0.0515, pos2.y0, pos2.width - 0.005, pos2.height + 0.03])
    ax3.set_position([pos3.x0 - 0.080, pos3.y0 - 0.1251, pos3.width + 0.16, pos1.height + 0.16])

    output_path = 'plots/T.II.A_sim_emergence_of_markov_blankets.png'
    # savefig does not create missing directories
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    plt.savefig(output_path, dpi=300)
    plt.show()

# Script entry point: run the full experiment only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_simulation_experiment()
